# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation for data transfer/redistribution between topologies
`.. currentmodule : hysop.core.mpi.redistribute
See hysop.operator.redistribute.Redistribute for automatic
redistribute deployment.
* :class:`~RedistributeOperatorBase` abstract base class
* :class:`~RedistributeIntra` for topologies/operators defined
inside the same mpi communicator
* :class:`~RedistributeInter` for topologies/operators defined
on two different mpi communicator
* :class:`~RedistributeOverlap` for topologies defined
inside the same mpi parent communicator and
with a different number of processes
"""
from hashlib import sha1
import numpy as np
from hysop.constants import Backend, DirectionLabels, MemoryOrdering
from hysop.tools.htypes import check_instance, to_set, first_not_None
from hysop.tools.decorators import debug
from hysop.tools.numpywrappers import npw, slices_empty
from hysop.tools.mpi_utils import get_mpi_order
from hysop.topology.cartesian_topology import Topology, CartesianTopology, TopologyView
from hysop.topology.topology_descriptor import TopologyDescriptor
from hysop.core.mpi.topo_tools import TopoTools
from hysop.core.mpi.bridge import Bridge, BridgeOverlap, BridgeInter
from hysop.operator.base.redistribute_operator import RedistributeOperatorBase
from hysop.core.graph.computational_operator import ComputationalGraphOperator
from hysop.core.graph.graph import op_apply
from hysop import MPI, MPIParams
from hysop.parameters.scalar_parameter import ScalarParameter, TensorParameter
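
# Set DEBUG_REDISTRIBUTE to a non-zero value to print verbose diagnostics
# (resolutions, ghost layouts, communication indices/types) from the
# redistribute operators defined below.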
DEBUG_REDISTRIBUTE = 0
def _memcpy(dst, src, target_indices, source_indices, skind=None, tkind=None):
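    """Backend-aware copy of ``src[source_indices]`` into ``dst[target_indices]``.

    Dispatch depends on the source/target backend kinds (host or OpenCL);
    returns the event associated with the OpenCL copy kernel, or None for a
    plain host-to-host copy.
    """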
def _runtime_error():
msg = "Copy from {} to {} are not handled yet."
msg = msg.format(src.__class__, dst.__class__)
raise RuntimeError(msg)
assert src.dtype == dst.dtype
skind = src.backend.kind if skind is None else skind
tkind = dst.backend.kind if tkind is None else tkind
evt = None
if skind == Backend.HOST:
if tkind == Backend.HOST:
dst[target_indices] = src[source_indices]
elif tkind == Backend.OPENCL:
from hysop.backend.device.opencl.opencl_copy_kernel_launchers import (
OpenClCopyBufferRectLauncher,
)
knl = OpenClCopyBufferRectLauncher.from_slices(
varname="redistribute",
src=src,
dst=dst,
src_slices=source_indices,
dst_slices=target_indices,
)
evt = knl(queue=dst.default_queue)
else:
_runtime_error()
elif skind == Backend.OPENCL:
from hysop.backend.device.opencl.opencl_copy_kernel_launchers import (
OpenClCopyBufferRectLauncher,
)
if tkind == Backend.HOST:
knl = OpenClCopyBufferRectLauncher.from_slices(
varname="redistribute",
src=src,
dst=dst,
src_slices=source_indices,
dst_slices=target_indices,
)
evt = knl(queue=src.default_queue)
elif tkind == Backend.OPENCL:
assert src.backend.cl_env is dst.backend.cl_env
knl = OpenClCopyBufferRectLauncher.from_slices(
varname="redistribute",
src=src,
dst=dst,
src_slices=source_indices,
dst_slices=target_indices,
)
evt = knl(queue=src.default_queue)
else:
_runtime_error()
else:
_runtime_error()
return evt
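
# Minimal usage sketch for _memcpy (``fin``/``fout`` are hypothetical discrete
# fields with matching dtype and slice shapes, not actual HySoP objects):
#
#   evt = _memcpy(fout.sdata, fin.sdata,
#                 target_indices=fout.compute_slices,
#                 source_indices=fin.compute_slices)
#   if evt is not None:
#       evt.wait()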
class RedistributeIntra(RedistributeOperatorBase):
"""Data transfer between two operators/topologies.
Source and target must:
*be CartesianTopology topologies with the same global resolution
*be defined on the same communicator
*work on the same number of mpi processes
"""
@classmethod
def can_redistribute(cls, source_topo, target_topo, **kwds):
tin = source_topo
tout = target_topo
        # source and target must be CartesianTopology topologies defined on HostArrayBackend
if not isinstance(source_topo, CartesianTopology):
return False
if not isinstance(target_topo, CartesianTopology):
return False
# source and target must have the same global resolution
source_res = tin.mesh.grid_resolution
target_res = tout.mesh.grid_resolution
if not npw.array_equal(source_res, target_res):
return False
        # defined on the same communicator
        # and work on the same number of mpi processes
if not TopoTools.compare_comm(tin.parent, tout.parent):
return False
return True
def __new__(cls, **kwds):
return super().__new__(cls, **kwds)
def __init__(self, **kwds):
"""Data transfer between two operators/topologies defined on the
same communicator
Source and target must:
*be defined on the same communicator
*work on the same number of mpi process
*work with the same global resolution
"""
# Base class initialisation
super().__init__(**kwds)
        # Warning: comm from io_params will be used as
        # reference for all mpi communications of this operator.
        # --> rank computed in refcomm
        # --> source and target must work inside refcomm
        # If io_params is None, refcomm will be COMM_WORLD.
@debug
def discretize(self):
super().discretize()
        # Dictionary of discrete fields to be sent and received
v = self.variable
self._vsource = {v: self.input_discrete_fields[v]}
self._vtarget = {v: self.output_discrete_fields[v]}
# we can create the bridge
ifield = self.input_discrete_fields[self.variable]
ofield = self.output_discrete_fields[self.variable]
source_topo = ifield.topology
target_topo = ofield.topology
sstate = source_topo.topology_state
tstate = target_topo.topology_state
if (
(sstate.dim != tstate.dim)
or (sstate.axes != tstate.axes)
or (sstate.memory_order != tstate.memory_order)
):
msg = "Topology state mismatch between source and target."
msg += "\nSource topology state:"
msg += str(sstate)
msg += "\nTarget topology state:"
msg += str(tstate)
raise RuntimeError(msg)
assert all(source_topo.mesh.local_resolution == ifield.resolution)
assert all(target_topo.mesh.local_resolution == ofield.resolution)
self.bridge = Bridge(
source_topo, target_topo, self.dtype, get_mpi_order(ifield.sdata)
)
self._rank = self.bridge._rank
        # dictionary which maps each rank to the mpi derived type
        # used for send operations
self._send = self.bridge.send_types()
        # dictionary which maps each rank to the mpi derived type
        # used for receive operations
self._receive = self.bridge.recv_types()
self._has_requests = False
self.dFin = ifield
self.dFout = ofield
@op_apply
def apply(self, **kwds):
# Try different way to send vars?
# - Buffered : copy all data into a buffer and send/recv
# - Standard : one send/recv
dFin, dFout = self.dFin, self.dFout
super().apply(**kwds)
# --- Standard send/recv ---
br = self.bridge
        # dictionary which maps rank/field name to a receive request
self._r_request = {}
        # dictionary which maps rank/field name to a send request
self._s_request = {}
basetag = self.mpi_params.rank + 1
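        # Point-to-point tags encode the (receiver, sender) pair: a message sent
        # from rank s to rank r is tagged (r+1)*989 + (s+1)*99 on both sides, so
        # each Irecv below matches exactly one Issend on the remote process.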
        # Comm used for send/receive operations.
        # It must contain all processes of the source and
        # target topologies.
refcomm = self.bridge.comm
v = self.variable
local_evts = ()
v_name = v.name
# Deal with local copies of data
if br.has_local_inter():
dst = self._vtarget[v].sdata
src = self._vsource[v].sdata
axes = self._vtarget[v].topology_state.axes
source_indices = br.local_source_ind()
target_indices = br.local_target_ind()
evt = _memcpy(dst, src, target_indices, source_indices)
if evt is not None:
local_evts += (evt,)
# Transfers to other mpi processes
for rk in self._receive:
if rk == self._rank:
continue
recvtag = basetag * 989 + (rk + 1) * 99
mpi_type = self._receive[rk]
dst = self._vtarget[v].sdata
assert dst.backend.kind is Backend.HOST
self._r_request[v_name + str(rk)] = refcomm.Irecv(
[dst.handle, 1, mpi_type], source=rk, tag=recvtag
)
self._has_requests = True
for rk in self._send:
if rk == self._rank:
continue
sendtag = (rk + 1) * 989 + basetag * 99
mpi_type = self._send[rk]
src = self._vsource[v].sdata
assert src.backend.kind is Backend.HOST
self._s_request[v_name + str(rk)] = refcomm.Issend(
[src.handle, 1, mpi_type], dest=rk, tag=sendtag
)
self._has_requests = True
for evt in local_evts:
evt.wait()
if self._has_requests:
for rk in self._r_request:
self._r_request[rk].Wait()
for rk in self._s_request:
self._s_request[rk].Wait()
self._has_requests = False
if DEBUG_REDISTRIBUTE:
print("resolution, compute_resolution, ghosts, compute_slices")
print(
dFin.resolution,
dFin.compute_resolution,
dFin.ghosts,
dFin.compute_slices,
)
print(
dFout.resolution,
dFout.compute_resolution,
dFout.ghosts,
dFout.compute_slices,
)
print()
print("BEFORE")
dFout.print_with_ghosts()
dFout.exchange_ghosts()
if DEBUG_REDISTRIBUTE:
print("AFTER")
dFout.print_with_ghosts()
mean_in = refcomm.allreduce(
dFin.sdata[dFin.compute_slices].sum().get()
) / float(refcomm.size)
mean_out = refcomm.allreduce(
dFout.sdata[dFout.compute_slices].sum().get()
) / float(refcomm.size)
assert npw.isclose(mean_in, mean_out), f"{mean_in} != {mean_out}"
class RedistributeInter(RedistributeOperatorBase):
"""Data transfer between two operators/topologies.
Source and target must:
*be CartesianTopology topologies with the same global resolution
*be defined on different communicators
"""
@classmethod
def can_redistribute(cls, source_topo, target_topo, other_task_id=None, **kwds):
tin = source_topo
tout = target_topo
        # source and target are defined on different tasks:
        # (one topology is None) or (there are two topologies on different tasks)
if not (
(
isinstance(tin, CartesianTopology)
and not isinstance(tout, CartesianTopology)
)
or (
isinstance(tout, CartesianTopology)
and not isinstance(tin, CartesianTopology)
)
):
if (tout is None and tin is None) or (
tout.mpi_params.task_id == tin.mpi_params.task_id
):
return False
# source and target must have the same global resolution
if isinstance(tout, CartesianTopology) and not isinstance(
tin, CartesianTopology
):
tout_id = tout.mpi_params.task_id
_is_source, _is_dest = False, True
_other_task = other_task_id
domain = tout.domain
other_resol = npw.zeros_like(tout.mesh.grid_resolution)
my_resol = tout.mesh.grid_resolution
elif isinstance(tin, CartesianTopology) and not isinstance(
tout, CartesianTopology
):
tin_id = tin.mpi_params.task_id
_is_source, _is_dest = True, False
_other_task = other_task_id
domain = tin.domain
other_resol = npw.zeros_like(tin.mesh.grid_resolution)
my_resol = tin.mesh.grid_resolution
elif isinstance(tout, CartesianTopology) and isinstance(tin, CartesianTopology):
tout_id = tout.mpi_params.task_id
tin_id = tin.mpi_params.task_id
_is_source, _is_dest = True, True
_other_task = other_task_id
domain = tin.domain
other_resol = tin.mesh.grid_resolution
my_resol = tout.mesh.grid_resolution
else:
return False
if domain.task_rank() == 0:
if _is_source and not _is_dest:
domain.parent_comm.send(
tin.mesh.grid_resolution,
dest=domain.task_root_in_parent(_other_task),
)
other_resol = domain.parent_comm.recv(
source=domain.task_root_in_parent(_other_task)
)
if _is_dest and not _is_source:
other_resol = domain.parent_comm.recv(
source=domain.task_root_in_parent(_other_task)
)
domain.parent_comm.send(
tout.mesh.grid_resolution,
dest=domain.task_root_in_parent(_other_task),
)
other_resol = domain.task_comm.bcast(other_resol, root=0)
if not npw.array_equal(my_resol, other_resol):
return False
return True
def __new__(cls, other_task_id=None, **kwds):
return super().__new__(cls, **kwds)
def __init__(self, mpi_params=None, other_task_id=None, **kwds):
"""
Data transfer between two operators/topologies.
Source and target must:
*be CartesianTopology topologies with the same global resolution
*be defined on different communicators
"""
if not kwds["variable"] is None:
self.fake_init = False
# Base class initialisation
super().__init__(mpi_params=mpi_params, **kwds)
self._other_task_id = other_task_id
self._synchronize(kwds["source_topo"], kwds["target_topo"])
else:
# Fake init. Should be called again later
self.fake_init = True
self.initialized = True
self.name = "TempName"
self.pretty_name = "TempName"
self.mpi_params = mpi_params
self._input_fields_to_dump = []
self._output_fields_to_dump = []
self.input_fields = {}
self.output_fields = {}
self.input_params = {}
self.output_params = {}
def _synchronize(self, tin, tout):
"""Ensure that the two redistributes are operating on the same variable"""
v = self.variable
in_name, out_name = "" if tin is None else v.name, (
"" if tout is None else v.name
)
domain = first_not_None((tin, tout)).domain
# Exchange names on root ranks first
if domain.task_rank() == 0 and in_name != out_name:
rcv_name = domain.parent_comm.sendrecv(
v.name,
sendtag=self._other_task_id,
recvtag=first_not_None((tin, tout)).mpi_params.task_id,
dest=domain.task_root_in_parent(self._other_task_id),
source=domain.task_root_in_parent(self._other_task_id),
)
in_name, out_name = (
rcv_name if _ == "" else _ for _ in (in_name, out_name)
)
        # then broadcast the other task's name on local ranks
        if tout is not None:
            in_name = tout.mpi_params.comm.bcast(in_name, root=0)
        if tin is not None:
            out_name = tin.mpi_params.comm.bcast(out_name, root=0)
assert in_name == out_name and in_name == v.name
def output_topology_state(self, output_field, input_topology_states):
"""
Determine a specific output discrete topology state given all
input discrete topology states.
Must be redefined to help correct computational graph generation.
By default, just return first input state if all input states are all the same.
If input_topology_states are different, raise a RuntimeError as default
behaviour. Operators altering the state of their outputs *have* to
override this method.
The state may include transposition state, memory order and more.
see hysop.topology.transposition_state.TranspositionState for the complete list.
"""
from hysop.fields.continuous_field import Field
from hysop.topology.topology import TopologyState
check_instance(output_field, Field)
check_instance(input_topology_states, dict, keys=Field, values=TopologyState)
assert output_field in self.output_fields.keys()
assert len(set(input_topology_states.keys())) == 0 or set(
input_topology_states.keys()
) == set(self.input_fields.keys())
if input_topology_states:
ref_field, _ = next(iter(input_topology_states.items()))
ref_topo = self.input_fields[ref_field]
ref_state = self.output_fields[output_field].topology_state
for ifield, istate in input_topology_states.items():
itopo = self.input_fields[ifield]
if not (
istate.dim == ref_state.dim
and istate.axes == ref_state.axes
and istate.memory_order == ref_state.memory_order
):
msg = "\nInput topology state for field {} defined on topology {} does "
msg += "not match reference input topology state {} defined on topology {} "
msg += "for operator {}.\n"
msg += (
"ComputationalGraphOperator default behaviour is to raise an error "
)
msg += "when all input states do not match exactly.\n\n"
msg += "Reference state: {}\n"
msg += "Offending state: {}\n\n"
msg += "This behaviour can be changed by overriding output_topology_state() for "
msg += "your custom operator needs."
msg = msg.format(
ifield.name,
itopo.tag,
ref_field.name,
ref_topo.tag,
self.name,
ref_state,
istate,
)
raise RuntimeError(msg)
return ref_state.copy()
@debug
def get_field_requirements(self):
reqs = super().get_field_requirements()
for f in self.input_fields:
try:
_ = reqs.get_input_requirement(f)
except RuntimeError:
reqs.update_inputs({f: reqs.get_output_requirement(f)[1]})
for f in self.output_fields:
try:
_ = reqs.get_output_requirement(f)
except RuntimeError:
reqs.update_outputs({f: reqs.get_input_requirement(f)[1]})
        # Note: we enforce the C-order here to simplify the communication. As most
        # of HySoP works in any- or C-order, this is not a big overhead (if a field
        # is in F-order, memory reordering is likely to be present already).
for is_input, requirements in reqs.iter_requirements():
if requirements is None:
continue
(field, td, req) = requirements
req.memory_order = MemoryOrdering.C_CONTIGUOUS
return reqs
@debug
def _check_inout_topology_states(
self, ifields, itopology_states, ofields, otopology_states
):
if (ifields != itopology_states) and (ofields != otopology_states):
msg = "\nFATAL ERROR: {}::{}.handle_topologies()\n\n"
msg = msg.format(type(self).__name__, self.name)
if not ((ifields != itopology_states) and (ofields == otopology_states)):
msg += "input_topology_states fields did not match operator's input Fields.\n"
if ifields - itopology_states:
msg += (
"input_topology_states are missing the following Fields: {}\n"
)
msg = msg.format(ifields - itopology_states)
else:
msg += (
"input_topology_states is providing useless extra Fields: {}\n"
)
msg = msg.format(itopology_states - ifields)
if not ((ofields != otopology_states) and (ifields == itopology_states)):
msg += "output_topology_states fields did not match operator's output Fields.\n"
if ofields - otopology_states:
msg += (
"output_topology_states are missing the following Fields: {}\n"
)
msg = msg.format(ofields - otopology_states)
else:
msg += (
"output_topology_states is providing useless extra Fields: {}\n"
)
msg = msg.format(otopology_states - ofields)
raise RuntimeError(msg)
@debug
def _check_variables(self):
"""
Check input and output variables.
Called automatically in ComputationalGraphNode.check()
"""
try:
super()._check_variables()
except TypeError:
for (ik, iv), (ok, ov) in zip(
self.input_fields.items(), self.output_fields.items()
):
if not (
(iv is None and isinstance(ov, TopologyView))
or (ov is None and isinstance(iv, TopologyView))
):
if iv is None:
msg = "Expected a Topology instance because input topo is None but got a {}.".format(
ov.__class__
)
msg += "\nAll topologies are expected to be set after "
msg += "ComputationalGraph.get_field_requirements() has been called."
raise TypeError(msg)
if ov is None:
msg = "Expected a Topology instance because output topo is None but got a {}.".format(
iv.__class__
)
msg += "\nAll topologies are expected to be set after "
msg += "ComputationalGraph.get_field_requirements() has been called."
raise TypeError(msg)
def discretize(self):
super().discretize()
# we can create the bridge
ifield, ofield = None, None
if self.variable in self.input_discrete_fields:
ifield = self.input_discrete_fields[self.variable]
if self.variable in self.output_discrete_fields:
ofield = self.output_discrete_fields[self.variable]
_is_source, _is_target = False, False
source_topo, target_topo, source_id = None, None, None
source_tstate, target_tstate = None, None
if ifield is not None:
_is_source = True
source_topo = ifield.topology
source_id = source_topo.mpi_params.task_id
target_id = self._other_task_id
source_tstate = (
source_topo.topology_state.dim,
source_topo.topology_state.axes,
source_topo.topology_state.memory_order,
)
if DEBUG_REDISTRIBUTE != 0:
print(
"This is a redistribute of {} from source topology {}".format(
self.variable.name, source_topo.tag
)
)
if ofield is not None:
_is_target = True
target_topo = ofield.topology
target_id = target_topo.mpi_params.task_id
source_id = self._other_task_id if source_id is None else source_id
target_tstate = (
target_topo.topology_state.dim,
target_topo.topology_state.axes,
target_topo.topology_state.memory_order,
)
if DEBUG_REDISTRIBUTE != 0:
print(
"This is a redistribute of {} to target topology {}".format(
self.variable.name, target_topo.tag
)
)
self._synchronize(source_topo, target_topo)
domain = first_not_None((source_topo, target_topo)).domain
self._source_id, self._target_id = source_id, target_id
        # compute a tag from the algebraic relation:
        #   x, y in [0; ss-1] and Tag = y + ss*(x + ss*(HASH/ss^2))
        #   therefore y = Tag % ss and x = (Tag/ss) % ss
ss = domain.parent_comm.Get_size() + 1
h = int(
sha1((self.variable.name + "RedistributeInter").encode()).hexdigest(), 16
) % (1 << 31)
basetag = ss * ss * (npw.uint32(h) / (100 * ss * ss))
self._tag = lambda x, y: npw.uint32(basetag + x * ss + y)
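        # Both sides derive the same tag for a given exchange: the source calls
        # self._tag(rk + 1, rank + 1) and the target self._tag(rank + 1, rk + 1),
        # which both evaluate to _tag(target_rank + 1, source_rank + 1).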
# Exchange on root ranks first ...
if domain.task_rank() == 0 and source_tstate != target_tstate:
rcv_tstate = domain.parent_comm.sendrecv(
first_not_None((source_tstate, target_tstate)),
sendtag=self._other_task_id,
recvtag=first_not_None((source_topo, target_topo)).mpi_params.task_id,
dest=domain.task_root_in_parent(self._other_task_id),
source=domain.task_root_in_parent(self._other_task_id),
)
source_tstate, target_tstate = (
            rcv_tstate if _ is None else _ for _ in (source_tstate, target_tstate)
)
# ... then broadcast
if _is_source:
target_tstate = source_topo.mpi_params.comm.bcast(target_tstate, root=0)
if _is_target:
source_tstate = target_topo.mpi_params.comm.bcast(source_tstate, root=0)
if not (source_tstate == target_tstate):
msg = "Topology state mismatch between source and target."
msg += "\nSource topology state:"
msg += str(source_tstate)
msg += "\nTarget topology state:"
msg += str(target_tstate)
raise RuntimeError(msg)
if _is_source:
assert all(source_topo.mesh.local_resolution == ifield.resolution)
if _is_target:
assert all(target_topo.mesh.local_resolution == ofield.resolution)
# Create bridges and store comm types and indices
        if domain.tasks_overlapping(source_id, target_id) is not None:
self.bridge = BridgeOverlap(
source=source_topo,
target=target_topo,
source_id=source_id,
target_id=target_id,
dtype=self.dtype,
order=get_mpi_order(first_not_None((ifield, ofield)).sdata),
)
else:
self.bridge = BridgeInter(
current=first_not_None((source_topo, target_topo)),
source_id=source_id,
target_id=target_id,
dtype=self.dtype,
order=get_mpi_order(first_not_None((ifield, ofield)).sdata),
)
        # dictionary that maps a rank to the mpi derived type needed
        # (send on the source side, receive on the target side)
self._comm_types, self._comm_indices = {}, {}
if _is_source:
self._comm_types[source_id] = self.bridge.transfer_types(task_id=source_id)
self._comm_indices[source_id] = self.bridge.transfer_indices(
task_id=source_id
)
if _is_target:
self._comm_types[target_id] = self.bridge.transfer_types(task_id=target_id)
self._comm_indices[target_id] = self.bridge.transfer_indices(
task_id=target_id
)
self._has_requests = False
if DEBUG_REDISTRIBUTE != 0:
print("RedistributeInter communication indices", self._comm_indices)
print("RedistributeInter communication types", self._comm_types)
self.dFin = ifield
self.dFout = ofield
self._need_copy_before, self._need_copy_after = False, False
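        # The MPI send/recv calls in apply() operate on host buffers: when a
        # discrete field lives on a device backend (e.g. OpenCL), stage its data
        # through a host scratch array, copied before send / after receive.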
if ifield is not None:
            if ifield.backend.kind is not Backend.HOST:
self._need_copy_before = True
self._dFin_data = ifield.backend.host_array_backend.empty_like(
ifield.buffers[0]
).handle
self._dFin_data[...] = 0.0
else:
self._dFin_data = ifield.sdata.handle
if ofield is not None:
            if ofield.backend.kind is not Backend.HOST:
self._need_copy_after = True
self._dFout_data = ofield.backend.host_array_backend.empty_like(
ofield.buffers[0]
).handle
self._dFout_data[...] = 0.0
else:
self._dFout_data = ofield.sdata.handle
self._is_source = _is_source
self._is_target = _is_target
@op_apply
def apply(self, **kwds):
comm = self.bridge.comm
rank = comm.Get_rank()
types = self._comm_types
indices = self._comm_indices
dFin, dFout = self.dFin, self.dFout
# TODO : Using GPU-aware MPI would simplify the usage of _memcpy
if self._is_source:
for rk, t in types[self._source_id].items():
if self._need_copy_before:
_memcpy(
self._dFin_data,
self.dFin.sdata,
target_indices=indices[self._source_id][rk],
source_indices=indices[self._source_id][rk],
skind=Backend.OPENCL,
tkind=Backend.HOST,
)
sendtag = self._tag(rk + 1, rank + 1)
comm.Isend([self._dFin_data, 1, t], dest=rk, tag=sendtag)
if self._is_target:
for rk, t in types[self._target_id].items():
recvtag = self._tag(rank + 1, rk + 1)
comm.Recv([self._dFout_data, 1, t], source=rk, tag=recvtag)
if self._need_copy_after:
_memcpy(
self.dFout.sdata,
self._dFout_data,
target_indices=indices[self._target_id][rk],
source_indices=indices[self._target_id][rk],
skind=Backend.HOST,
tkind=Backend.OPENCL,
)
self.dFout.exchange_ghosts()
class RedistributeInterParam(ComputationalGraphOperator):
"""parameter transfer between two operators/topologies.
Source and target must:
*be MPIParams defined on different communicators
"""
@classmethod
def supports_mpi(cls):
return True
def __new__(
cls, parameter, source_topo, target_topo, other_task_id, domain, **kwds
):
return super().__new__(cls, **kwds)
def __init__(
self, parameter, source_topo, target_topo, other_task_id, domain, **kwds
):
"""
Communicate parameter through tasks
parameter
----------
parameter: tuple of ScalarParameter or TensorParameter
parameters to communicate
source_topo: MPIParam
target_topo: MPIParam
"""
check_instance(parameter, tuple, values=(ScalarParameter, TensorParameter))
check_instance(source_topo, MPIParams, allow_none=True)
check_instance(target_topo, MPIParams, allow_none=True)
input_fields, output_fields = {}, {}
input_params, output_params = {}, {}
assert not (source_topo is None and target_topo is None)
        if source_topo is not None and source_topo.on_task:
            input_params = {p: source_topo for p in parameter}
        if target_topo is not None and target_topo.on_task:
            output_params = {p: target_topo for p in parameter}
super().__init__(
mpi_params=first_not_None(source_topo, target_topo),
input_params=input_params,
output_params=output_params,
input_fields=input_fields,
output_fields=output_fields,
**kwds,
)
self.initialized = True
self.domain = domain
self.source_task = other_task_id if source_topo is None else source_topo.task_id
self.target_task = other_task_id if target_topo is None else target_topo.task_id
self.task_is_source = domain.is_on_task(self.source_task)
self.task_is_target = domain.is_on_task(self.target_task)
if self.task_is_source:
assert source_topo.on_task
if self.task_is_target:
assert target_topo.on_task
self.inter_comm = domain.task_intercomm(
self.target_task if self.task_is_source else self.source_task
)
if self.inter_comm.is_inter:
# Disjoint tasks with real inter-communicator
self._the_apply = self._apply_intercomm
elif self.inter_comm.is_intra:
            # Overlapping tasks using an intra-communicator from the union of the tasks' processes
self._the_apply = self._apply_intracomm
self._all_params_by_type = {}
for p in sorted(self.parameters, key=lambda _: _.name):
            if p.dtype not in self._all_params_by_type:
self._all_params_by_type[p.dtype] = []
self._all_params_by_type[p.dtype].append(p)
self._send_temp_by_type = {
t: np.zeros((len(self._all_params_by_type[t]),), dtype=t)
for t in self._all_params_by_type.keys()
}
self._recv_temp_by_type = {
t: np.zeros((len(self._all_params_by_type[t]),), dtype=t)
for t in self._all_params_by_type.keys()
}
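
    # Construction sketch (``dt``, ``src_mpi_params``, ``dst_mpi_params``,
    # ``other_task`` and ``box`` are hypothetical placeholders):
    #
    #   op = RedistributeInterParam(parameter=(dt,),
    #                               source_topo=src_mpi_params,
    #                               target_topo=dst_mpi_params,
    #                               other_task_id=other_task,
    #                               domain=box)
    #   op.apply()  # broadcast parameter values from the source task to the target task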
@op_apply
def apply(self, **kwds):
self._the_apply(**kwds)
def _apply_intercomm(self, **kwds):
"""Disjoint tasks so inter-comm bcast is needed."""
for t in self._all_params_by_type.keys():
if self.task_is_source:
self._send_temp_by_type[t][...] = [
p() for p in self._all_params_by_type[t]
]
self.inter_comm.bcast(
self._send_temp_by_type[t],
root=MPI.ROOT if self.domain.task_rank() == 0 else MPI.PROC_NULL,
)
if self.task_is_target:
self._recv_temp_by_type[t] = self.inter_comm.bcast(
self._send_temp_by_type[t], root=0
)
for p, v in zip(
self._all_params_by_type[t], self._recv_temp_by_type[t]
):
p.value = v
def _apply_intracomm(self, **kwds):
"""Communicator is an intra-communicator defined as tasks' comm union.
Single broadcast is enough.
"""
for t in self._all_params_by_type.keys():
if self.task_is_source and self.domain.task_rank() == 0:
self._send_temp_by_type[t][...] = [
p() for p in self._all_params_by_type[t]
]
self._recv_temp_by_type[t] = self.inter_comm.bcast(
self._send_temp_by_type[t],
self.domain.task_root_in_parent(self.source_task),
)
if self.task_is_target:
for p, v in zip(
self._all_params_by_type[t], self._recv_temp_by_type[t]
):
p.value = v